"""

Example usage. The softmax temperature is tau = 1 / tau_denomi, i.e. attention logits are multiplied by tau_denomi.

python -m brainscuba \
    --subject_names subj01 subj02 subj05 subj07 \
    --atlasname streams floc-faces floc-places floc-bodies floc-words \
    --betas_norm \
    --modality image \
    --modality_hparam default \
    --model_name CLIP-ViT-B-32 \
    --reduce_dims default 0 \
    --dataset_name OpenImages \
    --max_samples full \
    --dataset_path ./data/OpenImages/frames_518x518px \
    --voxel_selection full 0 \
    --layer_selection layerLast \
    --tau_denomi 150 \
    --device cpu

"""

import os
import numpy as np
from tqdm import tqdm
import argparse
from utils.utils import search_best_layer, load_frames, make_filename, create_volume_index_and_weight_map
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
import math
import cupy as cp
import atexit
import signal
from nsd_access import NSDAccess
from CLIP_prefix_caption.notebooks.clip_prefix_captioning_inference import Voxel_Captioner
import cortex


def scaled_dot_product_attention(query, key, value, scale=None):
    """Collapse ``value`` rows into one vector using softmax attention from ``query``.

    The attention logits are ``(query @ key^T) * scale``. They are normalized with a
    softmax over dim 0, which is the image axis for the shapes used in this file.
    Each ``value`` row is split into a unit direction and its L2 norm. Both parts are
    averaged with the same attention weights. The result is the attention-weighted
    mean direction, rescaled by the attention-weighted mean norm.

    Args:
        query: tensor; ``(1, 1, M)`` as built by ``process_voxel``.
        key: tensor transposed on its last two dims; ``(K, 1, M)`` in this file.
        value: tensor with one row per image, matmul-compatible with the weights.
        scale: multiplier for the logits; defaults to ``1 / sqrt(M)``.

    Returns:
        ``(W_proj, attn_weight)``. For the shapes above, ``W_proj`` is ``(1, M)``
        and ``attn_weight`` is ``(K, 1, 1)``.
    """
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    # Scaled dot product. This is a cosine similarity only when both query and key
    # have been L2-normalized beforehand.
    attn_weight = query @ key.transpose(-2, -1) * scale_factor

    # Softmax over dim 0, i.e. across images (logits are (K, 1, 1) here).
    attn_weight = torch.softmax(attn_weight, dim=0)
    print(f"Max of attn_weight: {attn_weight.max()}")

    norm_factors = torch.norm(value, p=2, dim=-1, keepdim=True)
    # Zero-norm rows are divided by 1 instead of 0. They then contribute a zero
    # direction (and zero weight) instead of NaN.
    safe_norms = torch.where(norm_factors > 0, norm_factors, torch.ones_like(norm_factors))
    W_proj_dir = torch.sum(attn_weight @ (value / safe_norms), dim=0)
    W_proj_weight = torch.sum(attn_weight @ norm_factors, dim=0)
    W_proj = W_proj_weight * W_proj_dir

    return W_proj, attn_weight

def compute_W_proj_with_attention(W_orig, embeddings_norm, embeddings, tau_denomi):
    """Project one voxel weight vector onto a set of embeddings with softmax attention.

    Args:
        W_orig: torch.Tensor query; ``(1, 1, M)`` as built by ``process_voxel``.
        embeddings_norm: torch.Tensor used as the attention keys.
        embeddings: torch.Tensor used as the attention values (one row per image).
        tau_denomi: denominator of the softmax temperature, ``tau = 1 / tau_denomi``.
            The logits are therefore multiplied by ``tau_denomi``. Must be non-zero.

    Returns:
        ``(W_proj, attn_weight)`` as NumPy arrays. ``W_proj`` is the projected
        weight with its leading axis squeezed.

    Raises:
        ValueError: if ``tau_denomi`` is zero, since the temperature is then undefined.
    """
    if tau_denomi == 0:
        raise ValueError("tau_denomi must be non-zero (temperature tau = 1 / tau_denomi)")

    tau = 1 / tau_denomi
    print(f"Tau: {tau}")
    scale = 1 / tau

    print(f"Shape of W_orig: {W_orig.shape}")
    print(f"Shape of embeddings: {embeddings.shape}")
    W_proj, attn_weight = scaled_dot_product_attention(W_orig, embeddings_norm, embeddings, scale=scale)

    # Move to CPU before .numpy() so this also works when the tensors live on a CUDA device.
    return W_proj.squeeze(0).detach().cpu().numpy(), attn_weight.detach().cpu().numpy()

def compute_W_proj_in_batches(W_orig, embeddings_norm, embeddings, tau, batch_size):
    """Compute the attention-projected voxel weight over ``embeddings`` in chunks.

    Each chunk of ``batch_size`` images is projected independently, with its own
    softmax, by ``compute_W_proj_with_attention``. The per-chunk projections are then
    averaged with equal weight. When ``batch_size >= len(embeddings)`` this is a single
    softmax over all images.

    Args:
        W_orig: torch.Tensor query for one voxel.
        embeddings_norm: attention keys, one row per image. Each chunk is reshaped to
            the shape of the matching ``embeddings`` chunk.
        embeddings: attention values, one row per image.
        tau: despite its name, this is forwarded as ``tau_denomi``; the temperature
            is ``1 / tau``.
        batch_size: number of images per chunk; must be positive.

    Returns:
        ``(W_proj, attn_weight)``: the mean projected weight (NumPy) and the attention
        weights of the LAST chunk only.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``embeddings`` is empty.
    """
    num_images = embeddings.shape[0]
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if num_images == 0:
        raise ValueError("embeddings is empty; nothing to project")

    W_proj_batches = []
    for start in range(0, num_images, batch_size):
        stop = start + batch_size
        batch_embeddings = embeddings[start:stop]
        # Keys come from embeddings_norm, using the same window of images. They are
        # given the values' shape so the attention matmul batches over images the
        # same way for keys and values.
        batch_embeddings_norm = embeddings_norm[start:stop].reshape(batch_embeddings.shape)
        W_proj_batch, attn_weight = compute_W_proj_with_attention(
            W_orig, batch_embeddings_norm, batch_embeddings, tau
        )
        W_proj_batches.append(W_proj_batch)

    # Equal-weight mean of the per-chunk projections.
    W_proj = np.mean(W_proj_batches, axis=0)

    return W_proj, attn_weight

def cleanup_temp_file(temp_file):
    """Delete the in-progress marker file ``temp_file`` if it still exists.

    This may be called several times for the same path: from a ``finally`` block, an
    ``atexit`` hook and signal handlers. A missing file is therefore ignored, including
    when it disappears between the caller's decision to clean up and the removal itself.
    """
    try:
        os.remove(temp_file)
    except FileNotFoundError:
        return
    print(f"Temporary file {temp_file} deleted.")

def process_voxel(voxel_index, weight_index, subject_name, dataset_name, args, modality, modality_hparam, model_name, expanded_stims, 
                  layer_weight, use_gpu, captioner, tau_denomi=150, betanorm=False):
    """Project one voxel's encoding weight onto the stimulus embeddings and caption it.

    Output goes under ``./data/nsd/insilico/<subject>/.../voxel<6-digit index>``. Three
    files are written: the projected weight (``.npy``), the attention weights (``.npy``)
    and the generated caption (``.txt``).

    A ``<projected-weight path>.tmp`` marker prevents two workers from processing the
    same voxel. The voxel is skipped if that marker or the caption file already exists.

    Args:
        voxel_index: volume index; used for the output directory name.
        weight_index: column of ``layer_weight`` holding this voxel's weights.
        subject_name, dataset_name, modality, modality_hparam, model_name: used to
            build output paths.
        args: parsed CLI namespace; ``max_samples`` and ``reduce_dims`` are read.
        expanded_stims: stimulus embeddings, one row per image.
        layer_weight: encoding weights, indexed as ``layer_weight[:, weight_index]``.
        use_gpu: compared against the string ``'cuda'`` to choose the device.
            NOTE(review): ``main`` passes a bool here, so the comparison is always
            False and the computation runs on CPU. Confirm intent before changing.
        captioner: object whose ``generate_caption`` receives the saved ``.npy`` path.
        tau_denomi: softmax temperature denominator (``tau = 1 / tau_denomi``).
        betanorm: if True, L2-normalize the voxel weight and use L2-normalized
            embeddings as attention keys. Also selects the ``_betanorm`` filenames.
    """
    print(f"voxel_index: {voxel_index}")
    vindex_pad = str(voxel_index).zfill(6)
    resp_save_path = f"./data/nsd/insilico/{subject_name}/{dataset_name}_{args.max_samples}/{modality}/{modality_hparam}/{model_name}_{make_filename(args.reduce_dims[0:2])}/whole/voxel{vindex_pad}"
    os.makedirs(resp_save_path, exist_ok=True)

    save_tau_denomi_name = str(tau_denomi)

    # One batch containing every stimulus, so the softmax spans all images.
    batch_size = expanded_stims.shape[0]

    if betanorm:
        W_proj_i_save_path = f"{resp_save_path}/{model_name}_W_proj_dpa_taudenomi{save_tau_denomi_name}_betanorm.npy"
    else:
        W_proj_i_save_path = f"{resp_save_path}/{model_name}_W_proj_dpa_taudenomi{save_tau_denomi_name}.npy"
    # Re-enable to skip voxels whose projected weight already exists:
    # if os.path.exists(W_proj_i_save_path):
    #     print(f"Already processed {W_proj_i_save_path}")
    #     return

    # The marker file acts as a cross-process "in progress" flag for this voxel.
    temp_file = f"{W_proj_i_save_path}.tmp"
    if os.path.exists(temp_file):
        print(f"Already processing {W_proj_i_save_path}")
        return
    with open(temp_file, 'w') as f:
        f.write("Processing...")

    # Ensure the marker is deleted even if the process is interrupted.
    atexit.register(cleanup_temp_file, temp_file)
    signal.signal(signal.SIGINT, lambda sig, frame: cleanup_temp_file(temp_file) or exit(1))
    signal.signal(signal.SIGTERM, lambda sig, frame: cleanup_temp_file(temp_file) or exit(1))

    try:
        if betanorm:
            caption_save_path = f"{resp_save_path}/caption_brainscuba_tau{tau_denomi}_betanorm.txt"
        else:
            caption_save_path = f"{resp_save_path}/caption_brainscuba_tau{tau_denomi}.txt"

        if os.path.exists(caption_save_path):
            print(f"Already processed {caption_save_path}")
            return

        # NOTE(review): see the use_gpu note in the docstring.
        device = torch.device(use_gpu if torch.cuda.is_available() and use_gpu == 'cuda' else 'cpu')
        torch_dtype = torch.float32

        W_orig_i = layer_weight[:, weight_index].reshape(1, 1, -1)
        W_orig_i = torch.tensor(W_orig_i, dtype=torch_dtype).to(device)
        e_all = torch.tensor(expanded_stims, dtype=torch_dtype).to(device)
        if betanorm:
            W_orig_i = W_orig_i / torch.norm(W_orig_i, p=2, dim=-1, keepdim=True)
            e_all_norm = e_all / torch.norm(e_all, p=2, dim=-1, keepdim=True)
        else:
            # torch.Tensor has no .copy(); clone() gives an independent copy.
            e_all_norm = e_all.clone()

        # Weight statistics
        print(f"W's mean: {W_orig_i.mean()}, std: {W_orig_i.std()}, norm: {torch.norm(W_orig_i)}")
        # Embedding statistics
        print(f"Mean of the embs: {e_all.mean()}, std: {e_all.std()}, norm: {torch.norm(e_all[0])}")

        # Give values and keys a (K, -1, M) layout matching the (1, 1, M) query, so the
        # attention matmul is batched over images.
        e_all = e_all.reshape(e_all.shape[0], -1, W_orig_i.shape[-1])
        e_all_norm = e_all_norm.reshape(e_all.shape)

        W_proj_i, attn_weight = compute_W_proj_in_batches(W_orig_i, e_all_norm, e_all, tau_denomi, batch_size)
        print(f"Shape of the projected weight: {W_proj_i.shape}")
        print(f"Mean of the projected weight: {W_proj_i.mean()}, std of the projected weight: {W_proj_i.std()}, norm: {np.linalg.norm(W_proj_i)}")

        np.save(W_proj_i_save_path, W_proj_i)

        if betanorm:
            attn_weight_save_path = f"{resp_save_path}/{model_name}_attn_weight_ddpa_tau{save_tau_denomi_name}_betanorm.npy"
        else:
            attn_weight_save_path = f"{resp_save_path}/{model_name}_attn_weight_ddpa_tau{save_tau_denomi_name}.npy"
        np.save(attn_weight_save_path, attn_weight)

        # The captioner is given the path of the projected weight saved above.
        caption = captioner.generate_caption(W_proj_i_save_path)
        print(caption)
        with open(caption_save_path, "w") as f:
            f.write(caption)

    finally:
        cleanup_temp_file(temp_file)
        # The marker is already gone. Drop the exit hook so one hook per voxel does not
        # accumulate over a long run.
        atexit.unregister(cleanup_temp_file)



def load_and_prepare_dir(dir_name, frame_paths, stim_root_path, dataset_name, subject_name,
                         modality, modality_hparam, model_name, best_layer, reduce_dims):
    """Load the precomputed stimulus features for every frame in one directory.

    Features are read from
    ``<stim_root_path>/<dataset_name>/<modality>/<modality_hparam>/<model_name>/<best_layer>/<dir_name>/``.
    If ``reduce_dims[0]`` is not ``"default"``, the subject-specific reduced file
    ``<movname>_<subject_name>_ave_<make_filename(reduce_dims)>.npy`` is preferred.
    The raw ``<movname>.npy`` is used as a fallback.

    Returns:
        ``(dir_stims, mov_paths)``: the stacked features and the matching
        ``"<dir_name>/<movname>"`` identifiers. A singleton axis 1 is dropped, but the
        image axis is always kept. Returns ``(None, None)`` if ``frame_paths`` is empty.
    """
    if not frame_paths:
        return None, None

    stim_best_layer_path = os.path.join(
        stim_root_path, dataset_name, modality, modality_hparam,
        model_name, best_layer
    )
    stim_dir = f"{stim_best_layer_path}/{dir_name}"
    use_default = reduce_dims[0] == "default"
    reduced_suffix = None if use_default else make_filename(reduce_dims)

    dir_stims = []
    mov_paths = []
    for frame_path in frame_paths:
        # Strip every occurrence of the known media/text extensions to get the stimulus name.
        movname = os.path.basename(frame_path)
        for ext in (".txt", ".mp4", ".png", ".jpg", ".wav", ".mp3"):
            movname = movname.replace(ext, "")

        if use_default:
            stim = np.load(f"{stim_dir}/{movname}.npy")
        else:
            reduced_path = f"{stim_dir}/{movname}_{subject_name}_ave_{reduced_suffix}.npy"
            if os.path.exists(reduced_path):
                stim = np.load(reduced_path)
            else:
                stim = np.load(f"{stim_dir}/{movname}.npy")

        dir_stims.append(stim)
        mov_paths.append(f"{dir_name}/{movname}")

    dir_stims = np.array(dir_stims)
    # Drop only a singleton axis 1, e.g. (N, 1, M) -> (N, M). A bare squeeze() would
    # also drop the image axis when N == 1 and return a 1-D array.
    if dir_stims.ndim == 3 and dir_stims.shape[1] == 1:
        dir_stims = dir_stims.squeeze(axis=1)

    return dir_stims, mov_paths

def reduce_dimensions(dir_stims, reducer_projector):
    """Apply a fitted dimensionality reducer to a whole chunk of stimuli in one call.

    ``reducer_projector`` only needs a ``transform`` method; its result is returned as is.
    """
    return reducer_projector.transform(dir_stims)

def main(args):
    """Run the in-silico captioning pipeline for every subject in ``args.subject_names``.

    For each subject, this:
    1. picks the encoding layer and loads its weights;
    2. selects voxels via ``create_volume_index_and_weight_map``;
    3. loads the stimulus features of ``args.dataset_name``, dimension-reducing and
       caching them if requested;
    4. calls ``process_voxel`` for every selected voxel.

    Raises:
        ValueError: if ``args.dataset_name`` is not a supported dataset.
        RuntimeError: if no stimulus features could be loaded for a subject.
    """
    score_root_path = "./data/nsd/encoding"
    modality = args.modality
    modality_hparam = args.modality_hparam
    model_name = args.model_name
    file_type = args.voxel_selection[0]
    threshold = float(args.voxel_selection[1])
    nsda = NSDAccess('./data/NSD')
    # NOTE(review): this is a bool, but process_voxel compares its use_gpu argument with
    # the string 'cuda'. As a result, process_voxel currently always computes on CPU.
    use_gpu = torch.cuda.is_available() and args.device == "cuda"

    captioner = Voxel_Captioner()

    for subject_name in args.subject_names:
        print(subject_name)
        filename = make_filename(args.reduce_dims[0:2])
        if args.betas_norm:
            filename = filename + "_betanorm"

        print(f"Modality: {modality}, Modality hparams: {modality_hparam}, Feature: {model_name}, Filename: {filename}")
        # Select the encoding layer for this subject.
        model_score_dir = f"{score_root_path}/{subject_name}/scores/{modality}/{modality_hparam}/{model_name}"
        if args.layer_selection == "best":
            target_best_cv_layer, _, _ = search_best_layer(model_score_dir, filename, select_topN="all")
        else:
            target_best_cv_layer = args.layer_selection
        print(f"Best layer: {target_best_cv_layer}")

        # Encoding weights of the selected layer.
        layer_path = f"{model_score_dir}/{target_best_cv_layer}"
        layer_weight = np.load(f"{layer_path}/coef_{filename}.npy")
        print(f"Shape of the layer's weight: {layer_weight.shape}")

        # Voxels to process, plus the mapping from volume index to weight column.
        volume_index, weight_index_map, target_top_voxels = create_volume_index_and_weight_map(
            subject_name=subject_name,
            file_type=file_type,
            threshold=threshold,
            model_score_dir=model_score_dir,
            target_best_cv_layer=target_best_cv_layer,
            filename=filename,
            nsda=nsda,
            atlasnames=args.atlasname,  # assumes args.atlasname is a list
        )

        stim_root_path = "./data/stim_features/nsd"
        if args.reduce_dims[0] != "default":
            projector_stem = f"{stim_root_path}/{modality}/{modality_hparam}/{model_name}/{target_best_cv_layer}/projector_{subject_name}_ave_{filename}"
            # NOTE: allow_pickle deserializes arbitrary objects; only load trusted local projector files.
            try:
                reducer_projector = np.load(f"{projector_stem}.npy", allow_pickle=True).item()
            except Exception:
                # Fall back to a projector pickled with a .pkl extension.
                reducer_projector = np.load(f"{projector_stem}.pkl", allow_pickle=True)
        else:
            reducer_projector = None
        print(reducer_projector)

        dataset_name = args.dataset_name

        if args.max_samples == "full":
            break_point = 100000000
        else:
            break_point = int(args.max_samples)

        if dataset_name == "OpenImages":
            frames_all = load_frames(f"{args.dataset_path}", dataset_name)
            print(f"Number of directory: {len(frames_all)}")
        else:
            # Without this, frames_all would be undefined below.
            raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

        # Load stimulus features.
        count = 0
        all_stims = []
        all_mov_paths = []
        batch_size = 20  # number of directories loaded concurrently per round

        insilico_stim_root_path = f"./data/stim_features/"
        stim_best_layer_path = os.path.join(
            insilico_stim_root_path, dataset_name, modality, modality_hparam,
            model_name, target_best_cv_layer
        )
        # Process all directories, batch_size directories at a time.
        dir_names = list(frames_all.keys())
        for i in range(0, len(dir_names), batch_size):
            batch_dir_names = dir_names[i:i + batch_size]

            with ThreadPoolExecutor() as load_executor:
                future_to_dirname = {load_executor.submit(load_and_prepare_dir, dir_name, frames_all[dir_name], insilico_stim_root_path, dataset_name, subject_name, modality, modality_hparam, model_name, target_best_cv_layer, args.reduce_dims): dir_name for dir_name in batch_dir_names}

                loaded_data = []
                for future in tqdm(as_completed(future_to_dirname), total=len(future_to_dirname)):
                    dir_stims, mov_paths = future.result()
                    if dir_stims is not None:
                        loaded_data.append((dir_stims, mov_paths))
                        count += len(mov_paths)
                        if count >= break_point:
                            break

            if args.reduce_dims[0] != "default":
                for dir_stims, mov_paths in tqdm(loaded_data):
                    if dir_stims.shape[1] != int(args.reduce_dims[1]):
                        dir_stims_transformed = reduce_dimensions(dir_stims, reducer_projector)

                        # Cache the reduced features next to the raw ones, writing in parallel.
                        # NOTE(review): this name uses main's `filename` (which may include
                        # "_betanorm"), while load_and_prepare_dir looks up
                        # make_filename(reduce_dims). Confirm the two agree.
                        with ThreadPoolExecutor() as executor:
                            futures = []
                            for idx, movname in enumerate(mov_paths):
                                save_path = f"{stim_best_layer_path}/{movname}_{subject_name}_ave_{filename}.npy"
                                # save_npy is not defined or imported in this module;
                                # np.save writes the array to save_path (already ending in .npy).
                                futures.append(executor.submit(np.save, save_path, dir_stims_transformed[idx]))

                            # Surface any write error.
                            for future in tqdm(futures, desc="Saving files"):
                                future.result()
                    else:
                        dir_stims_transformed = dir_stims

                    all_stims.append(dir_stims_transformed)
                    all_mov_paths.extend(mov_paths)
            else:
                for dir_stims, mov_paths in loaded_data:
                    all_stims.append(dir_stims)
                    all_mov_paths.extend(mov_paths)

            if count >= break_point:
                break

        if not all_stims:
            raise RuntimeError(f"No stimulus features were loaded from {stim_best_layer_path}")
        # Concatenate all chunks into one array (one row per image).
        all_stims = np.concatenate(all_stims, axis=0).squeeze()
        print(f"Reduced features shape: {all_stims.shape}")
        if args.device == "cuda":
            all_stims = cp.asarray(all_stims)

        for vol_idx in volume_index:
            weight_index = weight_index_map[vol_idx]
            process_voxel(vol_idx, weight_index, subject_name, dataset_name, args, modality, modality_hparam, model_name,
                          all_stims, layer_weight, use_gpu, captioner, args.tau_denomi, args.betas_norm)
                
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate BrainSCUBA captions for selected voxels."
    )
    parser.add_argument(
        "--subject_names",
        nargs="*",
        type=str,
        required=True,
        help="Subject IDs, e.g. subj01 subj02.",
    )
    parser.add_argument(
        "--atlasname",
        type=str,
        nargs="*",
        required=True,
        help="Atlas names forwarded to create_volume_index_and_weight_map.",
    )
    parser.add_argument(
        "--betas_norm",
        action="store_true",
        help="Use the *_betanorm encoding weights and L2-normalize weights/embeddings.",
    )
    parser.add_argument(
        "--modality",
        type=str,
        required=True,
        help="Name of the modality to use."
    )
    parser.add_argument(
        "--modality_hparam",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--reduce_dims",
        nargs="*",
        type=str,
        required=True,
        help="'default' for raw features; otherwise a reducer spec whose second value is the target dimensionality.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--max_samples",
        type=str,
        required=True,
        help="'full' or a maximum number of stimuli to load.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--voxel_selection",
        nargs=2,
        type=str,
        required=True,
        metavar=("FILE_TYPE", "THRESHOLD"),
    )
    parser.add_argument(
        "--layer_selection",
        type=str,
        required=True,
        help="'best' to pick via search_best_layer, or an explicit layer name.",
    )
    parser.add_argument(
        "--tau_denomi",
        type=float,
        # The previous default of 0 always failed with a division by zero (tau = 1 / tau_denomi).
        default=150.0,
        help="Denominator of the softmax temperature (tau = 1 / tau_denomi); must be non-zero.",
    )
    parser.add_argument(
        "--device",
        type=str,
        required=True,
        choices=["cuda", "cpu"],
        help="Device to use."
    )
    args = parser.parse_args()
    main(args)
